from centralized_verification.utils import TrainingProgress


def default_epsilon_schedule(training_progress):
    """Return a fixed exploration rate, independent of training progress.

    Args:
        training_progress: Accepted only so this function has the same call
            signature as the annealing schedules; it is never read.

    Returns:
        The constant epsilon ``0.1``.
    """
    constant_epsilon = 0.1
    return constant_epsilon


def _linear_anneal(start: float, end: float, count: float, total: float) -> float:
    """Interpolate linearly from *start* to *end* as *count* runs from 0 to *total*.

    The interpolation fraction ``count / total`` is clamped to ``[0, 1]``:
    counts at or below 0 yield *start*, and counts at or beyond *total* yield
    *end*. A non-positive *total* describes an empty annealing window. In that
    case the schedule is treated as already finished and *end* is returned,
    rather than dividing by zero.
    """
    if total <= 0:
        frac = 1.0
    else:
        frac = min(max(float(count) / total, 0.0), 1.0)
    return (1 - frac) * start + frac * end


def linear_epsilon_anneal_steps(start: float, end: float, num_steps: int):
    """Build an epsilon schedule that anneals linearly over global steps.

    Args:
        start: Epsilon at global step 0.
        end: Epsilon from global step ``num_steps`` onwards.
        num_steps: Length of the annealing window in global steps. A value
            of ``0`` or less means epsilon is ``end`` from the start.

    Returns:
        A callable taking a ``TrainingProgress`` and returning the epsilon
        for its ``global_step_count``.
    """

    def eps_anneal_func(training_progress: "TrainingProgress") -> float:
        return _linear_anneal(
            start, end, training_progress.global_step_count, num_steps
        )

    return eps_anneal_func


def linear_epsilon_anneal_episodes(start: float, end: float, num_episodes: int):
    """Build an epsilon schedule that anneals linearly over global episodes.

    Args:
        start: Epsilon at global episode 0.
        end: Epsilon from global episode ``num_episodes`` onwards.
        num_episodes: Length of the annealing window in global episodes. A
            value of ``0`` or less means epsilon is ``end`` from the start.

    Returns:
        A callable taking a ``TrainingProgress`` and returning the epsilon
        for its ``global_episode_count``.
    """

    def eps_anneal_func(training_progress: "TrainingProgress") -> float:
        return _linear_anneal(
            start, end, training_progress.global_episode_count, num_episodes
        )

    return eps_anneal_func
